package mcp

import (
	"context"
	"fmt"
	"sync"
	"time"

	"github.com/Tencent/WeKnora/internal/logger"
	"github.com/Tencent/WeKnora/internal/types"
)

// MCPManager manages MCP client connections
type MCPManager struct {
	clients   map[string]MCPClient // serviceID -> client
	clientsMu sync.RWMutex
	ctx       context.Context
	cancel    context.CancelFunc
}

// NewMCPManager creates a new MCP manager
func NewMCPManager() *MCPManager {
	ctx, cancel := context.WithCancel(context.Background())

	manager := &MCPManager{
		clients: make(map[string]MCPClient),
		ctx:     ctx,
		cancel:  cancel,
	}

	// Start cleanup goroutine
	go manager.cleanupIdleConnections()

	return manager
}

// GetOrCreateClient gets an existing client or creates a new one
// For stdio transport, always creates a new client (not cached)
// For SSE/HTTP Streamable, caches and reuses existing connections
func (m *MCPManager) GetOrCreateClient(service *types.MCPService) (MCPClient, error) {
	// Check if service is enabled
	if !service.Enabled {
		return nil, fmt.Errorf("MCP service %s is not enabled", service.Name)
	}

	// For stdio transport, always create a new client (don't cache)
	if service.TransportType == types.MCPTransportStdio {
		return m.createStdioClient(service)
	}

	// For SSE/HTTP Streamable, check if client already exists and reuse
	m.clientsMu.RLock()
	client, exists := m.clients[service.ID]
	m.clientsMu.RUnlock()

	if exists && client.IsConnected() {
		return client, nil
	}

	// Create new client
	m.clientsMu.Lock()
	defer m.clientsMu.Unlock()

	// Double check after acquiring write lock
	client, exists = m.clients[service.ID]
	if exists && client.IsConnected() {
		return client, nil
	}

	// Create new client
	config := &ClientConfig{
		Service: service,
	}

	client, err := NewMCPClient(config)
	if err != nil {
		return nil, fmt.Errorf("failed to create MCP client: %w", err)
	}

	// For SSE connections, Connect() starts a persistent connection that needs a long-lived context
	// Use manager's context (m.ctx) which persists for the lifetime of the manager
	// The HTTP client's timeout will handle connection timeouts, not context cancellation
	if err := client.Connect(m.ctx); err != nil {
		return nil, fmt.Errorf("failed to connect to MCP service: %w", err)
	}

	if err := m.initializeClient(service, client, "failed to initialize MCP client"); err != nil {
		return nil, err
	}

	// Store client (only for non-stdio transports)
	m.clients[service.ID] = client

	logger.GetLogger(m.ctx).Infof("MCP client created and initialized for service: %s", service.Name)
	return client, nil
}

// createStdioClient creates a new stdio client (not cached)
func (m *MCPManager) createStdioClient(service *types.MCPService) (MCPClient, error) {
	// Create new client
	config := &ClientConfig{
		Service: service,
	}

	client, err := NewMCPClient(config)
	if err != nil {
		return nil, fmt.Errorf("failed to create stdio MCP client: %w", err)
	}

	// For stdio, Connect() starts the subprocess
	// Use manager's context for the connection lifecycle
	if err := client.Connect(m.ctx); err != nil {
		return nil, fmt.Errorf("failed to connect to stdio MCP service: %w", err)
	}

	if err := m.initializeClient(service, client, "failed to initialize stdio MCP client"); err != nil {
		return nil, err
	}

	logger.GetLogger(m.ctx).Infof("MCP stdio client created and initialized for service: %s", service.Name)
	return client, nil
}

// initializeClient handles the shared initialization flow with timeout enforcement.
func (m *MCPManager) initializeClient(service *types.MCPService, client MCPClient, errPrefix string) error {
	initTimeout := 30 * time.Second
	if service.AdvancedConfig != nil && service.AdvancedConfig.Timeout > 0 {
		initTimeout = time.Duration(service.AdvancedConfig.Timeout) * time.Second
		if initTimeout > 60*time.Second {
			initTimeout = 60 * time.Second
		}
	}

	initCtx, initCancel := context.WithTimeout(m.ctx, initTimeout)
	defer initCancel()

	if _, err := client.Initialize(initCtx); err != nil {
		client.Disconnect()
		if errPrefix == "" {
			errPrefix = "failed to initialize MCP client"
		}
		return fmt.Errorf("%s: %w", errPrefix, err)
	}

	return nil
}

// GetClient gets an existing client
func (m *MCPManager) GetClient(serviceID string) (MCPClient, bool) {
	m.clientsMu.RLock()
	defer m.clientsMu.RUnlock()

	client, exists := m.clients[serviceID]
	return client, exists
}

// CloseClient closes and removes a specific client
func (m *MCPManager) CloseClient(serviceID string) error {
	m.clientsMu.Lock()
	defer m.clientsMu.Unlock()

	client, exists := m.clients[serviceID]
	if !exists {
		return nil
	}

	if err := client.Disconnect(); err != nil {
		logger.GetLogger(m.ctx).Errorf("Failed to disconnect MCP client %s: %v", serviceID, err)
	}

	delete(m.clients, serviceID)
	logger.GetLogger(m.ctx).Infof("MCP client closed: %s", serviceID)
	return nil
}

// CloseAll closes all clients
func (m *MCPManager) CloseAll() {
	m.clientsMu.Lock()
	defer m.clientsMu.Unlock()

	for serviceID, client := range m.clients {
		if err := client.Disconnect(); err != nil {
			logger.GetLogger(m.ctx).Errorf("Failed to disconnect MCP client %s: %v", serviceID, err)
		}
	}

	m.clients = make(map[string]MCPClient)
	logger.GetLogger(m.ctx).Info("All MCP clients closed")
}

// Shutdown gracefully shuts down the manager
func (m *MCPManager) Shutdown() {
	m.cancel()
	m.CloseAll()
}

// cleanupIdleConnections periodically cleans up disconnected clients
func (m *MCPManager) cleanupIdleConnections() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-m.ctx.Done():
			return
		case <-ticker.C:
			m.removeDisconnectedClients()
		}
	}
}

// removeDisconnectedClients removes clients that are no longer connected
func (m *MCPManager) removeDisconnectedClients() {
	m.clientsMu.Lock()
	defer m.clientsMu.Unlock()

	for serviceID, client := range m.clients {
		if !client.IsConnected() {
			delete(m.clients, serviceID)
			logger.GetLogger(m.ctx).Infof("Removed disconnected MCP client: %s", serviceID)
		}
	}
}

// GetActiveClients returns the number of active clients
func (m *MCPManager) GetActiveClients() int {
	m.clientsMu.RLock()
	defer m.clientsMu.RUnlock()

	count := 0
	for _, client := range m.clients {
		if client.IsConnected() {
			count++
		}
	}
	return count
}

// ListActiveServices returns IDs of services with active connections
func (m *MCPManager) ListActiveServices() []string {
	m.clientsMu.RLock()
	defer m.clientsMu.RUnlock()

	services := make([]string, 0, len(m.clients))
	for serviceID, client := range m.clients {
		if client.IsConnected() {
			services = append(services, serviceID)
		}
	}
	return services
}
